【GAN】四、CGAN论文详解与代码详解

您所在的位置:网站首页 generator model 【GAN】四、CGAN论文详解与代码详解

【GAN】四、CGAN论文详解与代码详解

#【GAN】四、CGAN论文详解与代码详解| 来源: 网络整理| 查看: 265

前言

自从10月15号在广州的实习结束后,这将近1个月的时间由于学校各种实习相关手续、答辩和赶上毕业论文开题的节奏等原因,因此相关实习结束之前相关笔记没有及时。从今天开始,将恢复相关博客的更新。

在之前我们介绍了DCGAN与原始GAN的相关理论,并给出了DCGAN生成手写数字图像的代码。若有兴趣请分别移步如下链接:

本篇博客我们将介绍CGAN(条件GAN)论文的相关细节。CGAN的论文网址请移步:

CGAN生成手写数字的keras代码请移步:

一、 GAN回顾 \underset{G}{\mathop{\min }}\,\underset{D}{\mathop{\max }}\,V(D,G)={{\mathbb{E}}{x\sim {{p}{data}}(x)}}[\log D(x)]+{{\mathbb{E}}{z\sim {{p}{data}}(z)}}[\log (1-D(G(z)))]\tag1

为了兼顾CGAN的相关理论介绍,我们首先回顾GAN相关细节。GAN主要包括两个网络,一个是生成器 G 和判别器 D ,生成器的目的就是将随机输入的高斯噪声映射成图像(“假图”),判别器则是判断输入图像是否来自生成器的概率,即判断输入图像是否为假图的概率。

在这里我们假设数据为 x ,生成器的数据分布为 p_g ,噪声分布为 p_z(z) ,那么噪声 z 的结果可以记作 G(z;\theta_g) ,数据 x 在判别器 D 上的结果为 D(x;\theta_d) 。

那么GAN的目的就是无中生有,以假乱真。即要使得生成器 G 生成的所谓的"假图"骗过判别器 D ,那么最优状态就是生成器 G 生成的所谓的"假图"在判别器 D 的判别结果为0.5,不知道到底是真图还是假图。GAN的目标函数如下:

\underset{G}{\mathop{\min }}\,\underset{D}{\mathop{\max }}\,V(D,G)={{\mathbb{E}}_{x\sim {{p}{data}}(x)}}[\log D(x)]+{{\mathbb{E}}_{z\sim {{p}{data}}(z)}}[\log (1-D(G(z)))]\tag1

二、CGAN网络架构详解

在介绍CGAN的原理接下来介绍了CGAN的相关原理。原始的GAN的生成器只能根据随机噪声进行生成图像,至于这个图像是什么(即标签是什么我们无从得知),判别器也只能接收图像输入进行判别是否图像来使生成器。因此CGAN的主要贡献就是在原始GAN的生成器与判别器中的输入中加入额外信息$y$。额外信息$y$可以是任何信息,例如标签。因此CGAN的提出使得GAN可以利用图像与对应的标签进行训练,并在测试阶段 利用给定标签生成特定图像。

在CGAN的论文中,网络架构使用的MLP(全连接网络)。在CGAN中的生成器,我们给定一个输入噪声 p_z(z) 和额外信息 y ,之后将两者通过全连接层连接到一起,作为隐藏层输入。同样地,在判别器中输入图像 x 和 额外信息 y 也将连接到一起作为隐藏层输入。CGAN的网络架构图如下所示:

那么,CGAN的目标函数可以表述成如下形式:

\underset{G}{\mathop{\min }}\,\underset{D}{\mathop{\max }}\,V(D,G)={{\mathbb{E}}_{x\sim {{p}{data}}(x)}}[\log D(x|y)]+{{\mathbb{E}}_{z\sim {{p}{data}}(z)}}[\log (1-D(G(z|y)))]\tag2

下面是CGAN论文中生成的手写数字图像的结果,每一行代表有一个标签,例如第一行代表标签为0的图片。

三、CGAN-MNIST代码详解

接下来我们将主要介绍CGAN生成手写数字图像的keras代码。github链接为:

首先给出CGAN的网络架构代码:

# -*- coding: utf-8 -*- # @Time : 2019/10/8 13:39 # @Author : Dai PuWei # @File : CGAN.py # @Software: PyCharm import os import cv2 import numpy as np import datetime import matplotlib.pyplot as plt from scipy.stats import truncnorm from keras import Input from keras import Model from keras import Sequential from keras.layers import Dense from keras.layers import Activation from keras.layers import Reshape from keras.layers import Conv2DTranspose from keras.layers import BatchNormalization from keras.layers import Conv2D from keras.layers import LeakyReLU from keras.layers import Dropout from keras.layers import Flatten from keras.layers.merge import multiply from keras.layers.merge import concatenate from keras.layers.merge import add from keras.layers import Embedding from keras.utils import to_categorical from keras.optimizers import Adam from keras.utils.generic_utils import Progbar from copy import deepcopy from keras.datasets import mnist def make_trainable(net, val): """ Freeze or unfreeze layers """ net.trainable = val for l in net.layers: l.trainable = val class CGAN(object): def __init__(self,config,weight_path=None): """ 这是CGAN的初始化函数 :param config: 参数配置类实例 :param weight_path: 权重文件地址,默认为None """ self.config = config self.build_cgan_model() if weight_path is not None: self.cgan.load_weights(weight_path,by_name=True) def build_cgan_model(self): """ 这是搭建CGAN模型的函数 :return: """ # 初始化输入 self.generator_noise_input = Input(shape=(self.config.generator_noise_input_dim,)) self.condational_label_input = Input(shape=(1,), dtype='int32') self.discriminator_image_input = Input(shape=self.config.discriminator_image_input_dim) # 定义优化器 self.optimizer = Adam(lr=2e-4, beta_1=0.5) # 构建生成器模型与判别器模型 self.discriminator_model = self.build_discriminator_model() self.discriminator_model.compile(optimizer=self.optimizer, loss=['binary_crossentropy'],metrics=['accuracy']) self.generator_model = self.build_generator() # 构建CGAN模型 self.discriminator_model.trainable = False self.cgan_input = [self.generator_noise_input,self.condational_label_input] generator_output = self.generator_model(self.cgan_input) cgan_output = self.discriminator_model([generator_output,self.condational_label_input]) self.cgan = Model(self.cgan_input,cgan_output) # 编译 #self.discriminator_model.compile(optimizer=self.optimizer,loss='binary_crossentropy') self.cgan.compile(optimizer=self.optimizer,loss=['binary_crossentropy']) def build_discriminator_model(self): """ 这是搭建生成器模型的函数 :return: """ model = Sequential() model.add(Dense(512, input_dim=np.prod(self.config.discriminator_image_input_dim))) model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha)) model.add(Dense(512)) model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha)) model.add(Dropout(self.config.LeakyReLU_alpha)) model.add(Dense(512)) model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha)) model.add(Dropout(self.config.LeakyReLU_alpha)) model.add(Dense(1, activation='sigmoid')) model.summary() img = Input(shape=self.config.discriminator_image_input_dim) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.config.condational_label_num, np.prod(self.config.discriminator_image_input_dim))(label)) flat_img = Flatten()(img) model_input = multiply([flat_img, label_embedding]) validity = model(model_input) return Model([img, label], validity) def build_generator(self): """ 这是构建生成器网络的函数 :return:返回生成器模型generotor_model """ model = Sequential() model.add(Dense(256, input_dim=self.config.generator_noise_input_dim)) model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha)) model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum)) model.add(Dense(512)) model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha)) model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha)) model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum)) model.add(Dense(np.prod(self.config.discriminator_image_input_dim), activation='tanh')) model.add(Reshape(self.config.discriminator_image_input_dim)) model.summary() noise = Input(shape=(self.config.generator_noise_input_dim,)) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.config.condational_label_num, self.config.generator_noise_input_dim)(label)) model_input = multiply([noise, label_embedding]) img = model(model_input) return Model([noise, label], img) def train(self, train_datagen, epoch, k, batch_size=256): """ 这是DCGAN的训练函数 :param train_generator:训练数据生成器 :param epoch:周期数 :param batch_size:小批量样本规模 :param k:训练判别器次数 :return: """ time =datetime.datetime.now().strftime("%Y%m%d%H%M%S") model_path = os.path.join(self.config.model_dir,time) if not os.path.exists(model_path): os.mkdir(model_path) train_result_path = os.path.join(self.config.train_result_dir,time) if not os.path.exists(train_result_path): os.mkdir(train_result_path) for ep in np.arange(1, epoch+1).astype(np.int32): cgan_losses = [] d_losses = [] # 生成进度条 length = train_datagen.batch_num progbar = Progbar(length) print('Epoch {}/{}'.format(ep, epoch)) iter = 0 while True: # 遍历一次全部数据集,那么重新来结束while循环 #print("iter:{},{}".format(iter,train_datagen.get_epoch() != ep)) if train_datagen.epoch != ep: break # 获取真实图片,并构造真图对应的标签 batch_real_images, batch_real_labels = train_datagen.next_batch() batch_real_num_labels = np.ones((batch_size, 1)) #batch_real_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1)) # 初始化随机噪声,伪造假图,并合并真图和假图数据集 batch_noises = np.random.normal(0, 1, size = (batch_size, self.config.generator_noise_input_dim)) d_loss = [] for i in np.arange(k): # 构造假图标签,合并真图和假图对应标签 batch_fake_num_labels = np.zeros((batch_size,1)) #batch_fake_num_labels = truncnorm.rvs(0.0, 0.3, size=(batch_size, 1)) batch_fake_labels = deepcopy(batch_real_labels) batch_fake_images = self.generator_model.predict([batch_noises,batch_fake_labels]) # 训练判别器 real_d_loss = self.discriminator_model.train_on_batch([batch_real_images,batch_real_labels], batch_real_num_labels) fake_d_loss = self.discriminator_model.train_on_batch([batch_fake_images, batch_fake_labels], batch_fake_num_labels) d_loss.append(list(0.5*np.add(real_d_loss,fake_d_loss))) #print(d_loss) d_losses.append(list(np.average(d_loss,0))) #print(d_losses) # 生成一个batch_size的噪声来训练生成器 #batch_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1)) batch_num_labels = np.ones((batch_size,1)) batch_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1) cgan_loss = self.cgan.train_on_batch([batch_noises,batch_labels], batch_num_labels) cgan_losses.append(cgan_loss) # 更新进度条 progbar.update(iter, [('dcgan_loss', cgan_losses[iter]), ('discriminator_loss',d_losses[iter][0]), ('acc',d_losses[iter][1])]) #print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (ep, d_losses[ep][0], 100 * d_losses[ep][1],cgan_loss)) iter += 1 if ep % self.config.save_epoch_interval == 0: model_cgan = "Epoch{}dcgan_loss{}discriminator_loss{}acc{}.h5".format(ep, np.average(cgan_losses), np.average(d_losses,0)[0],np.average(d_losses,0)[1]) self.cgan.save(os.path.join(model_path, model_cgan)) save_dir = os.path.join(train_result_path, str("Epoch{}".format(ep))) if not os.path.exists(save_dir): os.mkdir(save_dir) self.save_image(int(ep), save_dir) ''' if int(ep) in self.config.generate_image_interval: save_dir = os.path.join(train_result_path,str("Epoch{}".format(ep))) if not os.path.exists(save_dir): os.mkdir(save_dir) self.save_image(ep,save_dir) ''' plt.plot(np.arange(epoch),cgan_losses,'b-','cgan-loss') plt.plot(np.arange(epoch), d_losses[0], 'b-', 'd-loss') plt.grid(True) plt.legend(locs="best") plt.xlabel("Epoch") plt.ylabel("Loss") plt.savefig(os.path.join(train_result_path,"loss.png")) def save_image(self, epoch,save_path): """ 这是保存生成图片的函数 :param epoch:周期数 :param save_path: 图片保存地址 :return: """ rows, cols = 10, 10 fig, axs = plt.subplots(rows, cols) for i in range(rows): label = np.array([i]*rows).astype(np.int32).reshape(-1,1) noise = np.random.normal(0, 1, (cols, 100)) images = self.generator_model.predict([noise,label]) images = 127.5*images+127.5 cnt = 0 for j in range(cols): #img_path = os.path.join(save_path, str(cnt) + ".png") #cv2.imwrite(img_path, images[cnt]) #axs[i, j].imshow(image.astype(np.int32)[:,:,0]) axs[i, j].imshow(images[cnt,:, :, 0].astype(np.int32), cmap='gray') axs[i, j].axis('off') cnt += 1 fig.savefig(os.path.join(save_path, "mnist-{}.png".format(epoch)), dpi=600) plt.close() def generate_image(self,label): """ 这是伪造一张图片的函数 :param label:标签 """ noise = truncnorm.rvs(-1, 1, size=(1, self.config.generator_noise_input_dim)) label = np.array([label]).T image = self.generator_model.predict([noise,label])[0] image = 127.5*(image+1) return image

为了训练我们必须还的构造一个数据集迭代器来读取小批量手写数字图像数据,数据集迭代器类的代码如下:

# -*- coding: utf-8 -*- # @Time : 2019/10/8 17:29 # @Author : Dai PuWei # @File : MnistGenerator.py # @Software: PyCharm import math import numpy as np from keras.datasets import mnist class MnistGenerator(object): def __init__(self,batch_size): """ 这是图像数据生成器的初始化函数 :param batch_size: 小批量样本规模 """ (x_train,y_train),(x_test,y_test) = mnist.load_data() #self.x = np.concatenate([x_train,x_test]).astype(np.float32) self.x = np.expand_dims((x_train.astype(np.float32)-127.5)/127.5,axis=-1) #self.y = to_categorical(np.concatenate([y_train,y_test]),num_classes=10) self.y = y_train.reshape(-1,1) #self.y = self.y[y == ] #print(np.shape(self.x)) #print(np.shape(self.y)) self.images_size = len(self.x) random_index = np.random.permutation(np.arange(self.images_size)) self.x = self.x[random_index] self.y = self.y[random_index] self.epoch = 1 # 当前迭代次数 self.batch_size = int(batch_size) self.batch_num = math.ceil(self.images_size / self.batch_size) self.start = 0 self.end = 0 self.finish_flag = False # 数据集是否遍历完一次标志 def _next_batch(self): """ :return: """ while True: #batch_images = np.array([]) #batch_labels = np.array([]) if self.finish_flag: # 数据集遍历完一次 random_index = np.random.permutation(np.arange(self.images_size)) self.x = self.x[random_index] self.y = self.y[random_index] self.finish_flag = False self.epoch += 1 self.end = int(np.min([self.images_size,self.start+self.batch_size])) batch_images = self.x[self.start:self.end] batch_labels = self.y[self.start:self.end] batch_size = self.end - self.start if self.end == self.images_size: # 数据集刚分均分 self.finish_flag = True if batch_size

第10个epoch之后的生成结果:

第100个epoch之后的生成结果:

第1000个epoch之后的生成结果:

下面是CGAN的测试代码:

# -*- coding: utf-8 -*- # @Time : 2019/11/8 13:11 # @Author : DaiPuWei # @Email : [email protected] # @File : test.py # @Software: PyCharm import os from CGAN.CGAN import CGAN from Config.Config import MnistConfig def run_main(): """ 这是主函数 """ weight_path = os.path.abspath("./model/20191009134644/Epoch1378dcgan_loss1.5952800512313843discriminator_loss[0.49839333 0.7379193 ]acc[0.49839333 0.7379193 ].h5") result_path = os.path.abspath("./test_result") if not os.path.exists(result_path): os.mkdir(result_path) cfg = MnistConfig() cgan = CGAN(cfg,weight_path) cgan.save_image(0,result_path) if __name__ == '__main__': run_main()

欢迎关注我的微信公众号:AI那点小事



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3